Writing Functions in Python

useful tricks, like how to write context managers and decorators
functions
python
Author

Victor Omondi

Published

July 26, 2020

Best Practices

Best practices when writing functions. We’ll cover docstrings and why they matter, and how to know when you need to turn a chunk of code into a function. We will also cover how Python passes arguments to functions, as well as some common gotchas that can cause debugging headaches when calling functions.

Docstrings

Crafting a docstring

The first function is count_letter(). It takes a string and a single letter and returns the number of times the letter appears in the string. We want the users of our open-source package to be able to understand how this function works easily, so we will need to give it a docstring.

import inspect
import pandas as pd
import numpy as np
import time
import contextlib
import os
def count_letter(content, letter):
    """Count the number of times `letter` appears in `content`.
    
    Args:
        content (str): The string to search.
        letter (str): The letter to search for.
        
    Returns:
        int
    Raises:
        ValueError: If `letter` is not a one-character string.
    """
    if (not isinstance(letter, str)) or len(letter) != 1:
        raise ValueError('`letter` must be a single character string.')
    return len([char for char in content if char == letter])
help(count_letter)
Help on function count_letter in module __main__:

count_letter(content, letter)
    Count the number of times `letter` appears in `content`.
    
    Args:
        content (str): The string to search.
        letter (str): The letter to search for.
        
    Returns:
        int
    Raises:
        ValueError: If `letter` is not a one-character string.

Retrieving docstrings

My friends and I are working on building an amazing new Python IDE (integrated development environment – like PyCharm, Spyder, Eclipse, Visual Studio, etc.). Our team wants to add a feature that displays a tooltip with a function’s docstring whenever the user starts typing the function name. That way, the user doesn’t have to go elsewhere to look up the documentation for the function they are trying to use. We’ve been asked to complete the build_tooltip() function that retrieves a docstring from an arbitrary function.

# Get the docstring with an attribute of count_letter()
docstring = count_letter.__doc__

border = '#' * 28
print('{}\n{}\n{}'.format(border, docstring, border))
############################
Count the number of times `letter` appears in `content`.
    
    Args:
        content (str): The string to search.
        letter (str): The letter to search for.
        
    Returns:
        int
    Raises:
        ValueError: If `letter` is not a one-character string.
    
############################
# Get the docstring with a function from the inspect module
docstring = inspect.getdoc(count_letter)

border = '#' * 28
print('{}\n{}\n{}'.format(border, docstring, border))
############################
Count the number of times `letter` appears in `content`.

Args:
    content (str): The string to search.
    letter (str): The letter to search for.
    
Returns:
    int
Raises:
    ValueError: If `letter` is not a one-character string.
############################
def build_tooltip(function):
    """Create a tooltip for any function that shows the function's docstring.
    
    Args:
        function (callable): The function we want a tooltip for.
    
    Returns:
        str
    """
    # Use 'inspect' to get the docstring
    docstring = inspect.getdoc(function)
    border = '#' * 28
    
    return '{}\n{}\n{}'.format(border, docstring, border)

print(build_tooltip(count_letter))
print(build_tooltip(range))
print(build_tooltip(print))
############################
Count the number of times `letter` appears in `content`.

Args:
    content (str): The string to search.
    letter (str): The letter to search for.
    
Returns:
    int
Raises:
    ValueError: If `letter` is not a one-character string.
############################
############################
range(stop) -> range object
range(start, stop[, step]) -> range object

Return an object that produces a sequence of integers from start (inclusive)
to stop (exclusive) by step.  range(i, j) produces i, i+1, i+2, ..., j-1.
start defaults to 0, and stop is omitted!  range(4) produces 0, 1, 2, 3.
These are exactly the valid indices for a list of 4 elements.
When step is given, it specifies the increment (or decrement).
############################
############################
print(value, ..., sep=' ', end='\n', file=sys.stdout, flush=False)

Prints the values to a stream, or to sys.stdout by default.
Optional keyword arguments:
file:  a file-like object (stream); defaults to the current sys.stdout.
sep:   string inserted between values, default a space.
end:   string appended after the last value, default a newline.
flush: whether to forcibly flush the stream.
############################

DRY and “Do one thing”

While we were developing a model to predict the likelihood of a student graduating from college, we wrote this bit of code to get the z-scores of students’ yearly GPAs. Now we’re ready to turn it into a production-quality system, so we need to do something about the repetition. Writing a function to calculate the z-scores would improve this code.

# Standardize the GPAs for each year
df['y1_z'] = (df.y1_gpa - df.y1_gpa.mean()) / df.y1_gpa.std()
df['y2_z'] = (df.y2_gpa - df.y2_gpa.mean()) / df.y2_gpa.std()
df['y3_z'] = (df.y3_gpa - df.y3_gpa.mean()) / df.y3_gpa.std()
df['y4_z'] = (df.y4_gpa - df.y4_gpa.mean()) / df.y4_gpa.std()

Note: df is a pandas DataFrame where each row is a student with 4 columns of yearly student GPAs: y1_gpa, y2_gpa, y3_gpa, y4_gpa

df = pd.read_csv('students.csv', index_col=0)
df.head()
y1_gpa y2_gpa y3_gpa y4_gpa
0 2.785877 2.052513 2.170544 0.065570
1 1.144557 2.666498 0.267098 2.884737
2 0.907406 0.423634 2.613459 0.030950
3 2.205259 0.523580 3.984345 0.339289
4 2.877876 1.287922 3.077589 0.901994
def standardize(column):
    """Standardize the values in a column.
    
    Args:
        column (pandas Series): The data to standardize.
        
    Returns:
        pandas Series: the values as z-scores
    """
    # Compute the z-scores from the Series passed in,
    # rather than reaching for the global df
    z_score = (column - column.mean()) / column.std()
    
    return z_score

# Use the standardize() function to calculate the z-scores
df['y1_z'] = standardize(df.y1_gpa)
df['y2_z'] = standardize(df.y2_gpa)
df['y3_z'] = standardize(df.y3_gpa)
df['y4_z'] = standardize(df.y4_gpa)

df.head()
y1_gpa y2_gpa y3_gpa y4_gpa y1_z y2_z y3_z y4_z
0 2.785877 2.052513 2.170544 0.065570 0.790863 0.028021 0.172322 -1.711179
1 1.144557 2.666498 0.267098 2.884737 -0.872971 0.564636 -1.347122 0.824431
2 0.907406 0.423634 2.613459 0.030950 -1.113376 -1.395595 0.525883 -1.742317
3 2.205259 0.523580 3.984345 0.339289 0.202281 -1.308243 1.620206 -1.464991
4 2.877876 1.287922 3.077589 0.901994 0.884124 -0.640219 0.896379 -0.958885

standardize() will probably be useful in other places in our code, and now it is easy to use, test, and update if we need to. It’s also easier to tell what the code is doing because of the docstring and the name of the function.

Split up a function

Another engineer on our team has written this function to calculate the mean and median of a list. We want to show them how to split it into two simpler functions: mean() and median().

def mean_and_median(values):
  """Get the mean and median of a list of `values`

  Args:
    values (iterable of float): A list of numbers

  Returns:
    tuple (float, float): The mean and median
  """
  mean = sum(values) / len(values)
  midpoint = int(len(values) / 2)
  if len(values) % 2 == 0:
    median = (values[midpoint - 1] + values[midpoint]) / 2
  else:
    median = values[midpoint]

  return mean, median
def mean(values):
    """Get the mean of a list of values

    Args:
        values (iterable of float): A list of numbers

    Returns:
        float
    """
    # Write the mean() function
    mean = sum(values) / len(values)
    
    return mean
def median(values):
    """Get the median of a list of values

    Args:
        values (iterable of float): A list of numbers

    Returns:
        float
    """
    # Write the median() function
    # Sort first: the middle element is only the median of ordered data
    values = sorted(values)
    midpoint = len(values) // 2
    if len(values) % 2 == 0:
        median = (values[midpoint - 1] + values[midpoint]) / 2
    else:
        median = values[midpoint]
    return median

Each function does one thing and does it well. Using, testing, and maintaining these will be a breeze (although we’ll probably just use numpy.mean() and numpy.median() for this in real life).
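As a quick sanity check, small functions like these are easy to compare against a trusted implementation. The sketch below uses the standard library's statistics module rather than NumPy; the names my_mean() and my_median() and the sample list gpas are made up for illustration.

```python
import statistics


def my_mean(values):
    """Hand-written mean of a list of numbers."""
    return sum(values) / len(values)


def my_median(values):
    """Hand-written median; sorts a copy of the data first."""
    ordered = sorted(values)
    midpoint = len(ordered) // 2
    if len(ordered) % 2 == 0:
        return (ordered[midpoint - 1] + ordered[midpoint]) / 2
    return ordered[midpoint]


gpas = [3.5, 2.0, 4.0, 3.0]  # made-up sample data

# Each hand-written result should match the library result
print(my_mean(gpas), statistics.mean(gpas))
print(my_median(gpas), statistics.median(gpas))
```

Because each function does only one thing, each comparison is a single line.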

Pass by assignment

Best practice for default arguments

One of our co-workers has written this function for adding a column to a pandas DataFrame. Unfortunately, they used a mutable variable as a default argument value!

def add_column(values, df=pd.DataFrame()):
  """Add a column of `values` to a DataFrame `df`.
  The column will be named "col_<n>" where "n" is
  the numerical index of the column.

  Args:
    values (iterable): The values of the new column
    df (DataFrame, optional): The DataFrame to update.
      If no DataFrame is passed, one is created by default.

  Returns:
    DataFrame
  """
  df['col_{}'.format(len(df.columns))] = values
  return df
# Use an immutable variable for the default argument 
def better_add_column(values, df=None):
    """Add a column of `values` to a DataFrame `df`. The column will be named "col_<n>" where "n" is the numerical index of the column.

    Args:
        values (iterable): The values of the new column
        df (DataFrame, optional): The DataFrame to update.
          If no DataFrame is passed, one is created by default.

    Returns:
        DataFrame
    """
    # Update the function to create a default DataFrame
    if df is None:
        df = pd.DataFrame()
    df['col_{}'.format(len(df.columns))] = values
    return df

When you need to set a mutable variable as a default argument, always use None and then set the value in the body of the function. This prevents unexpected behavior like adding multiple columns if you call the function more than once.
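The same gotcha can be reproduced without pandas. Here is a minimal sketch using a plain list; append_bad() and append_good() are hypothetical names chosen for this illustration.

```python
def append_bad(item, items=[]):
    """The default list is created once, when the function is defined."""
    items.append(item)
    return items


def append_good(item, items=None):
    """A fresh list is created on every call that omits `items`."""
    if items is None:
        items = []
    items.append(item)
    return items


first_bad = append_bad(1)
second_bad = append_bad(2)
print(second_bad)               # [1, 2]
print(first_bad is second_bad)  # True: both calls share one list

print(append_good(1))  # [1]
print(append_good(2))  # [2]
```

The default value is evaluated only once, at definition time, so every call that relies on it mutates the same object.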

Context Managers

Using context managers

We are working on a natural language processing project to determine what makes great writers so great. Our current hypothesis is that great writers talk about cats a lot. To prove it, we will count the number of times the word "cat" appears in “Alice’s Adventures in Wonderland” by Lewis Carroll. We have already downloaded a text file, alice.txt, with the entire contents of this great book.

# Open "alice.txt" and assign the file to "file"
with open('alice.txt', encoding="utf8") as file:
    text = file.read()

n = 0
for word in text.split():
    if word.lower() in ['cat', 'cats']:
        n += 1

print('Lewis Carroll uses the word "cat" {} times'.format(n))
Lewis Carroll uses the word "cat" 24 times

The speed of cats

We’re working on a new web service that processes Instagram feeds to identify which pictures contain cats (don’t ask why – it’s the internet). The code that processes the data is slower than we would like it to be, so we are working on tuning it up to run faster. Given an image, image, we have two functions that can process it:

  • process_with_numpy(image)
  • process_with_pytorch(image)

Our colleague wrote a context manager, timer(), that will print out how long the code inside the context block takes to run. She is suggesting we use it to see which of the two options is faster. We will time each function to determine which one to use in our web service.

def get_image_from_instagram():
    return np.random.rand(84, 84)

Writing context managers

The timer() context manager

A colleague of ours is working on a web service that processes Instagram photos. Customers are complaining that the service takes too long to identify whether or not an image has a cat in it, so our colleague has come to us for help. We decided to write a context manager that they can use to time how long their functions take to run.

@contextlib.contextmanager
def timer():
    """Time how long code in the context block takes to run."""
    t0 = time.time()
    try:
        yield
    finally:
        t1 = time.time()
        print('Elapsed: {:.2f} seconds'.format(t1 - t0))

Our colleague can now use our timer() context manager to figure out which of their functions is running too slowly. The three elements of a context manager are all here: a function definition, a yield statement, and the @contextlib.contextmanager decorator. timer() is a context manager that does not return an explicit value, so yield is written by itself without specifying anything to return.

def process_with_numpy(p):
    _process_pic(0.1521)
def _process_pic(n_sec):
    print('Processing', end='', flush=True)
    for i in range(10):
        print('.', end='' if i < 9 else 'done!\n', flush=True)
        time.sleep(n_sec)
def process_with_pytorch(p):
    _process_pic(0.0328)
image = get_image_from_instagram()

# Time how long process_with_numpy(image) takes to run
with timer():
    print('Numpy version')
    process_with_numpy(image)

# Time how long process_with_pytorch(image) takes to run
with timer():
    print('Pytorch version')
    process_with_pytorch(image)
Numpy version
Processing..........done!
Elapsed: 1.66 seconds
Pytorch version
Processing..........done!
Elapsed: 0.49 seconds

Now that we know the PyTorch version is faster, we can use it in our web service to ensure our users get the rapid response time they expect.

There is no as <variable name> at the end of the with statement when we use the timer() context manager. That is because timer() does not return a value, so the as <variable name> clause isn’t necessary.
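To make the point concrete, here is a small sketch of two throwaway context managers (no_value() and with_value() are hypothetical names). It shows that as binds whatever is yielded, which is None for a bare yield.

```python
import contextlib


@contextlib.contextmanager
def no_value():
    """Bare yield: there is nothing useful to bind with `as`."""
    yield


@contextlib.contextmanager
def with_value():
    """Yield a value: `as` binds it to a name."""
    yield 'hello'


with no_value() as a:
    print(a)  # None

with with_value() as b:
    print(b)  # hello
```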

A read-only open() context manager

We have a bunch of data files for our next deep learning project that took us months to collect and clean. It would be terrible if we accidentally overwrote one of those files when trying to read it in for training, so we decided to create a read-only version of the open() context manager to use in our project.

The regular open() context manager:

  • takes a filename and a mode ('r' for read, 'w' for write, or 'a' for append)
  • opens the file for reading, writing, or appending
  • sends control back to the context, along with a reference to the file
  • waits for the context to finish
  • and then closes the file before exiting

Our context manager will do the same thing, except it will only take the filename as an argument and it will only open the file for reading.

@contextlib.contextmanager
def open_read_only(filename):
    """Open a file in read-only mode.

    Args:
        filename (str): The location of the file to read

    Yields:
        file object
    """
    read_only_file = open(filename, mode='r')
    # Yield read_only_file so it can be assigned to my_file
    yield read_only_file
    # Close read_only_file
    read_only_file.close()

with open_read_only('my_file.txt') as my_file:
    print(my_file.read())
    Congratulations! You wrote a context manager that acts like "open()" but operates in read-only mode!

Scraping the NASDAQ

Training deep neural nets is expensive! We might as well invest in NVIDIA stock since we’re spending so much on GPUs. To pick the best time to invest, we are going to collect and analyze some data on how their stock is doing. The context manager stock('NVDA') will connect to the NASDAQ and return an object that you can use to get the latest price by calling its .price() method.

We want to connect to stock('NVDA') and record 10 timesteps of price data by writing them to the file NVDA.txt.

class MockStock:
    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.recent = list(np.random.laplace(loc, scale, 2))
    def price(self):
        sign = np.sign(self.recent[1] - self.recent[0])
        # 70% chance of going same direction
        sign = 1 if sign == 0 else (sign if np.random.rand() > 0.3 else -1 * sign)
        new = self.recent[1] + sign * np.random.rand() / 10.0
        self.recent = [self.recent[1], new]
        return new
@contextlib.contextmanager
def stock(symbol):
    base = 140.00
    scale = 1.0
    mock = MockStock(base, scale)
    print('Opening stock ticker for {}'.format(symbol))
    yield mock
    print('Closing stock ticker')
# Use the "stock('NVDA')" context manager
# and assign the result to the variable "nvda"
with stock('NVDA') as nvda:
    # Open "NVDA.txt" for writing as f_out
    with open('NVDA.txt', 'w') as f_out:
        for _ in range(10):
            value = nvda.price()
            print('Logging ${:.2f} for NVDA'.format(value))
            f_out.write('{:.2f}\n'.format(value))
Opening stock ticker for NVDA
Logging $140.25 for NVDA
Logging $140.17 for NVDA
Logging $140.12 for NVDA
Logging $140.04 for NVDA
Logging $140.01 for NVDA
Logging $139.93 for NVDA
Logging $139.86 for NVDA
Logging $139.80 for NVDA
Logging $139.89 for NVDA
Logging $139.89 for NVDA
Closing stock ticker

Now we can monitor the NVIDIA stock price and decide when is the exact right time to buy. Nesting context managers like this allows us to connect to the stock market (the CONNECT/DISCONNECT pattern) and write to a file (the OPEN/CLOSE pattern) at the same time.
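Nested context managers can also be written in a single with statement, separated by commas. The sketch below is a stand-in for illustration only: connect() is a hypothetical CONNECT/DISCONNECT manager, and io.StringIO plays the role of the output file.

```python
import contextlib
import io

events = []


@contextlib.contextmanager
def connect(name):
    """Hypothetical stand-in for a CONNECT/DISCONNECT resource."""
    events.append('connect ' + name)
    try:
        yield name
    finally:
        events.append('disconnect ' + name)


# One `with` statement, two context managers: equivalent to nesting them
with connect('ticker') as conn, io.StringIO() as buffer:
    buffer.write('price from {}\n'.format(conn))
    contents = buffer.getvalue()  # read before the buffer is closed

print(contents, end='')
print(events)
```

The managers are entered left to right and exited in reverse order, just as if the second with block were indented inside the first.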

Changing the working directory

We are using an open-source library that lets us train deep neural networks on our data. Unfortunately, during training, this library writes out checkpoint models (i.e., models that have been trained on a portion of the data) to the current working directory. We found that behavior frustrating because we don’t want to have to launch the script from the directory where the models will be saved.

We decided that one way to fix this is to write a context manager that changes the current working directory, lets us build our models, and then resets the working directory to its original location. We’ll want to be sure that any errors that occur during model training don’t prevent us from resetting the working directory to its original location.

@contextlib.contextmanager
def in_dir(directory):
    """Change current working directory to `directory`, allow the user to run some code, and change back.

    Args:
        directory (str): The path to a directory to work in.
    """
    current_dir = os.getcwd()
    os.chdir(directory)

    # Add code that lets you handle errors
    try:
        yield
    # Ensure the directory is reset,
    # whether there was an error or not
    finally:
        os.chdir(current_dir)

Now, even if someone writes buggy code when using our context manager, we will be sure to change the current working directory back to what it was when they called in_dir(). This is important to do because our users might be relying on their working directory being what it was when they started the script. in_dir() is a great example of the CHANGE/RESET pattern that indicates you should use a context manager.
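A short way to convince ourselves that the reset really happens is to raise an error inside the block and check the working directory afterwards. This sketch repeats in_dir() so that it runs on its own; the temporary directory and the simulated failure are made up for the demonstration.

```python
import contextlib
import os
import tempfile


@contextlib.contextmanager
def in_dir(directory):
    """Change into `directory`, then always change back."""
    current_dir = os.getcwd()
    os.chdir(directory)
    try:
        yield
    finally:
        os.chdir(current_dir)


start = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    try:
        with in_dir(tmp):
            raise RuntimeError('simulated training failure')
    except RuntimeError as err:
        print('Caught:', err)

# The finally block ran, so we are back where we started
print(os.getcwd() == start)
```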

Decorators

Decorators are an extremely powerful concept in Python. They allow us to modify the behavior of a function without changing the code of the function itself.

Functions are objects

Building a command line data app

We are building a command line tool that lets a user interactively explore a data set. We’ve defined four functions: mean(), std(), minimum(), and maximum() that users can call to analyze their data. Users can call any of these functions by typing the function name at the input prompt.

Note: The function get_user_input() in this exercise is a mock version of asking the user to enter a command. It randomly returns one of the four function names. In real life, we would ask for input and wait until the user entered a value.

import random
def get_user_input(prompt='Type a command: '):
    command = random.choice(['mean', 'std', 'minimum', 'maximum'])
    print(prompt)
    print('> {}'.format(command))
    return command
def mean(data):
    print(data.mean())
def std(data):
    print(data.std())
def minimum(data):
    print(data.min())
def maximum(data):
    print(data.max())
def load_data():
    df = pd.DataFrame()
    df['height'] = [72.1, 69.8, 63.2, 64.7]
    df['weight'] = [198, 204, 164, 238]
    return df
# Add the missing function references to the function map
function_map = {
  'mean': mean,
  'std': std,
  'minimum': minimum,
  'maximum': maximum
}

data = load_data()
print(data)

func_name = get_user_input()

# Call the chosen function and pass "data" as an argument
function_map[func_name](data)
   height  weight
0    72.1     198
1    69.8     204
2    63.2     164
3    64.7     238
Type a command: 
> std
height     4.194043
weight    30.309514
dtype: float64

Reviewing our co-worker’s code

Our co-worker is asking us to review some code that they’ve written and give them some tips on how to get it ready for production. We know that having a docstring is considered best practice for maintainable, reusable functions, so as a sanity check we decided to run this has_docstring() function on all of their functions.

def has_docstring(func):
    """Check to see if the function 
    `func` has a docstring.

    Args:
        func (callable): A function.

    Returns:
        bool
    """
    return func.__doc__ is not None
def load_and_plot_data(filename):
    """Load a data frame and plot each column.
  
    Args:
        filename (str): Path to a CSV file of data.
  
    Returns:
        pandas.DataFrame
    """
    df = pd.read_csv(filename, index_col=0)
    df.hist()
    return df
# Call has_docstring() on the load_and_plot_data() function
ok = has_docstring(load_and_plot_data)

if not ok:
    print("load_and_plot_data() doesn't have a docstring!")
else:
    print("load_and_plot_data() looks ok")
load_and_plot_data() looks ok
def as_2D(arr):
    """Reshape an array to 2 dimensions"""
    return np.array(arr).reshape(1, -1)
# Call has_docstring() on the as_2D() function
ok = has_docstring(as_2D)

if not ok:
    print("as_2D() doesn't have a docstring!")
else:
    print("as_2D() looks ok")
as_2D() looks ok
def log_product(arr):
    return np.exp(np.sum(np.log(arr)))
# Call has_docstring() on the log_product() function
ok = has_docstring(log_product)

if not ok:
    print("log_product() doesn't have a docstring!")
else:
    print("log_product() looks ok")
log_product() doesn't have a docstring!

Returning functions for a math game

We are building an educational math game where the player enters a math term, and our program returns a function that matches that term. For instance, if the user types “add”, our program returns a function that adds two numbers. So far we’ve only implemented the “add” function. Now we want to include a “subtract” function.

def create_math_function(func_name):
    if func_name == 'add':
        def add(a, b):
            return a + b
        return add
    elif func_name == 'subtract':
        # Define the subtract() function
        def subtract(a,b):
            return a-b
        return subtract
    else:
        print("I don't know that one")
    
add = create_math_function('add')
print('5 + 2 = {}'.format(add(5, 2)))

subtract = create_math_function('subtract')
print('5 - 2 = {}'.format(subtract(5, 2)))
5 + 2 = 7
5 - 2 = 3

Scope

Modifying variables outside local scope

call_count = 0

def my_function():
    # Use a keyword that lets us update call_count 
    global call_count
    call_count += 1
  
    print("You've called my_function() {} times!".format(
        call_count
    ))
  
for _ in range(20):
    my_function()
You've called my_function() 1 times!
You've called my_function() 2 times!
You've called my_function() 3 times!
You've called my_function() 4 times!
You've called my_function() 5 times!
You've called my_function() 6 times!
You've called my_function() 7 times!
You've called my_function() 8 times!
You've called my_function() 9 times!
You've called my_function() 10 times!
You've called my_function() 11 times!
You've called my_function() 12 times!
You've called my_function() 13 times!
You've called my_function() 14 times!
You've called my_function() 15 times!
You've called my_function() 16 times!
You've called my_function() 17 times!
You've called my_function() 18 times!
You've called my_function() 19 times!
You've called my_function() 20 times!
def read_files():
    file_contents = None
  
    def save_contents(filename):
        # Add a keyword that lets us modify file_contents
        nonlocal file_contents
        if file_contents is None:
            file_contents = []
        with open(filename) as fin:
            file_contents.append(fin.read())
      
    for filename in ['1984.txt', 'MobyDick.txt', 'CatsEye.txt']:
        save_contents(filename)
    
    return file_contents

print('\n'.join(read_files()))
It was a bright day in April, and the clocks were striking thirteen.
Call me Ishmael.
Time is not a line but a dimension, like the dimensions of space.
def wait_until_done():
    def check_is_done():
        # Add a keyword so that wait_until_done() 
        # doesn't run forever
        global done
        if random.random() < 0.1:
            done = True
      
    while not done:
        check_is_done()

done = False
wait_until_done()

print('Work done? {}'.format(done))
Work done? True

Closures

Checking for closures

def return_a_func(arg1, arg2):
    def new_func():
        print('arg1 was {}'.format(arg1))
        print('arg2 was {}'.format(arg2))
    return new_func
    
my_func = return_a_func(2, 17)

print(my_func.__closure__ is not None)
print(len(my_func.__closure__) == 2)

# Get the values of the variables in the closure
closure_values = [
  my_func.__closure__[i].cell_contents for i in range(2)
]
print(closure_values == [2, 17])
True
True
True

Closures keep your values safe

def my_special_function():
    print('You are running my_special_function()')

def get_new_func(func):
    def call_func():
        func()
    return call_func

new_func = get_new_func(my_special_function)

# Redefine my_special_function() to just print "hello"
def my_special_function():
    print("hello")

new_func()
You are running my_special_function()
def my_special_function():
    print('You are running my_special_function()')
  
def get_new_func(func):
    def call_func():
        func()
    return call_func

new_func = get_new_func(my_special_function)

# Delete my_special_function()
del my_special_function

new_func()
You are running my_special_function()
def my_special_function():
    print('You are running my_special_function()')
  
def get_new_func(func):
    def call_func():
        func()
    return call_func

# Overwrite `my_special_function` with the new function
my_special_function = get_new_func(my_special_function)

my_special_function()
You are running my_special_function()

Decorators

Using decorator syntax

print_args prints out all of the arguments and their values any time a function that it is decorating gets called.

def print_args(func):
    sig = inspect.signature(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs).arguments
        str_args = ', '.join(['{}={}'.format(k, v) for k, v in bound.items()])
        print('{} was called with {}'.format(func.__name__, str_args))
        return func(*args, **kwargs)
    return wrapper
def my_function(a, b, c):
    print(a + b + c)

# Decorate my_function() with the print_args() decorator
my_function = print_args(my_function)

my_function(1, 2, 3)
my_function was called with a=1, b=2, c=3
6
# Decorate my_function() with the print_args() decorator
@print_args
def my_function(a, b, c):
  print(a + b + c)

my_function(1, 2, 3)
my_function was called with a=1, b=2, c=3
6

Even though decorators are functions themselves, when we use decorator syntax with the @ symbol we do not include the parentheses after the decorator name.

Defining a decorator

Our buddy has been working on a decorator that prints a “before” message before the decorated function is called and prints an “after” message after the decorated function is called. They are having trouble remembering how wrapping the decorated function is supposed to work.

def print_before_and_after(func):
  def wrapper(*args):
    print('Before {}'.format(func.__name__))
    # Call the function being decorated with *args
    func(*args)
    print('After {}'.format(func.__name__))
  # Return the nested function
  return wrapper

@print_before_and_after
def multiply(a, b):
  print(a * b)

multiply(5, 10)
Before multiply
50
After multiply

Counter

We want to count how many times each of the functions in a web app gets called.

def counter(func):
  def wrapper(*args, **kwargs):
    wrapper.count += 1
    # Call the function being decorated and return the result
    return func(*args, **kwargs)
  wrapper.count = 0
  # Return the new decorated function
  return wrapper

# Decorate foo() with the counter() decorator
@counter
def foo():
  print('calling foo()')
  
foo()
foo()

print('foo() was called {} times.'.format(foo.count))
calling foo()
calling foo()
foo() was called 2 times.

Decorators and metadata

Preserving docstrings when decorating functions

def add_hello(func):
  def wrapper(*args, **kwargs):
    print('Hello')
    return func(*args, **kwargs)
  return wrapper

# Decorate print_sum() with the add_hello() decorator
@add_hello
def print_sum(a, b):
  """Adds two numbers and prints the sum"""
  print(a + b)
  
print_sum(10, 20)
print(print_sum.__doc__)
Hello
30
None
def add_hello(func):
  # Add a docstring to wrapper
  def wrapper(*args, **kwargs):
    """Print 'hello' and then call the decorated function."""
    print('Hello')
    return func(*args, **kwargs)
  return wrapper

@add_hello
def print_sum(a, b):
  """Adds two numbers and prints the sum"""
  print(a + b)
  
print_sum(10, 20)
print(print_sum.__doc__)
Hello
30
Print 'hello' and then call the decorated function.
from functools import wraps

def add_hello(func):
  # Decorate wrapper() so that it keeps func()'s metadata
  @wraps(func)
  def wrapper(*args, **kwargs):
    """Print 'hello' and then call the decorated function."""
    print('Hello')
    return func(*args, **kwargs)
  return wrapper
  
@add_hello
def print_sum(a, b):
  """Adds two numbers and prints the sum"""
  print(a + b)
  
print_sum(10, 20)
print(print_sum.__doc__)
Hello
30
Adds two numbers and prints the sum

Measuring decorator overhead

def check_inputs(a, *args, **kwargs):
  for value in a:
    time.sleep(0.01)
  print('Finished checking inputs')
def check_outputs(a, *args, **kwargs):
  for value in a:
    time.sleep(0.01)
  print('Finished checking outputs')
def check_everything(func):
  @wraps(func)
  def wrapper(*args, **kwargs):
    check_inputs(*args, **kwargs)
    result = func(*args, **kwargs)
    check_outputs(result)
    return result
  return wrapper
@check_everything
def duplicate(my_list):
  """Return a new list that repeats the input twice"""
  return my_list + my_list

t_start = time.time()
duplicated_list = duplicate(list(range(50)))
t_end = time.time()
decorated_time = t_end - t_start

t_start = time.time()
# Call the original function instead of the decorated one
duplicated_list = duplicate.__wrapped__(list(range(50)))
t_end = time.time()
undecorated_time = t_end - t_start

print('Decorated time: {:.5f}s'.format(decorated_time))
print('Undecorated time: {:.5f}s'.format(undecorated_time))
Finished checking inputs
Finished checking outputs
Decorated time: 2.40551s
Undecorated time: 0.00000s

Decorators that take arguments

run_n_times()

def run_n_times(n):
  """Define and return a decorator"""
  def decorator(func):
    def wrapper(*args, **kwargs):
      for i in range(n):
        func(*args, **kwargs)
    return wrapper
  return decorator
# Make print_sum() run 10 times with the run_n_times() decorator
@run_n_times(10)
def print_sum(a, b):
  print(a + b)
  
print_sum(15, 20)
35
35
35
35
35
35
35
35
35
35
# Use run_n_times() to create the run_five_times() decorator
run_five_times = run_n_times(5)

@run_five_times
def print_sum(a, b):
  print(a + b)
  
print_sum(4, 100)
104
104
104
104
104
# Modify the print() function to always run 20 times
print = run_n_times(20)(print)

print('What is happening?!?!')
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
What is happening?!?!
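# Restore the built-in print() by deleting the name that shadows it
# (wrapping it again with run_n_times(1) would still print 20 times)
del print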

HTML Generator

a script that generates HTML for a webpage on the fly.

def html(open_tag, close_tag):
  def decorator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
      msg = func(*args, **kwargs)
      return '{}{}{}'.format(open_tag, msg, close_tag)
    # Return the decorated function
    return wrapper
  # Return the decorator
  return decorator
# Make hello() return bolded text
@html("<b>", "</b>")
def hello(name):
  return 'Hello {}!'.format(name)
  
print(hello('Alice'))
<b>Hello Alice!</b>
%%html

<b>Hello Alice!</b>
Hello Alice!
# Make goodbye() return italicized text
@html("<i>", "</i>")
def goodbye(name):
  return 'Goodbye {}.'.format(name)
  
print(goodbye('Alice'))
<i>Goodbye Alice.</i>
%%html
<i>Goodbye Alice.</i>
Goodbye Alice.
# Wrap the result of hello_goodbye() in <div> and </div>
@html("<div>", "</div>")
def hello_goodbye(name):
  return '\n{}\n{}\n'.format(hello(name), goodbye(name))
  
print(hello_goodbye('Alice'))
<div>
<b>Hello Alice!</b>
<i>Goodbye Alice.</i>
</div>
%%html
<div>
<b>Hello Alice!</b>
<i>Goodbye Alice.</i>
</div>
Hello Alice! Goodbye Alice.

Timeout(): a real world example

Tag your functions

Tagging something means that you have given that thing one or more strings that act as labels. For instance, we often tag emails or photos so that we can search for them later. You’ve decided to write a decorator that will let you tag your functions with an arbitrary list of tags. You could use these tags for many things:

  • Adding information about who has worked on the function, so a user can look up who to ask if they run into trouble using it.
  • Labeling functions as “experimental” so that users know that the inputs and outputs might change in the future.
  • Marking any functions that you plan to remove in a future version of the code.
  • Etc.
def tag(*tags):
  # Define a new decorator, named "decorator", to return
  def decorator(func):
    # Ensure the decorated function keeps its metadata
    @wraps(func)
    def wrapper(*args, **kwargs):
      # Call the function being decorated and return the result
      return func(*args, **kwargs)
    wrapper.tags = tags
    return wrapper
  # Return the new decorator
  return decorator

@tag('test', 'this is a tag')
def foo():
  pass

print(foo.tags)
('test', 'this is a tag')
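Once functions carry tags, the tags can be searched. The following is only a sketch of one possible use: the two decorated functions, the tag strings, and the registry list are all hypothetical, and tag() is repeated so the example runs on its own.

```python
from functools import wraps


def tag(*tags):
    """Attach `tags` to the decorated function."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)
        wrapper.tags = tags
        return wrapper
    return decorator


@tag('experimental')
def new_model():
    return 'new'


@tag('stable', 'owner:alice')
def old_model():
    return 'old'


# Hypothetical registry of functions to search through
registry = [new_model, old_model]
experimental = [f.__name__ for f in registry if 'experimental' in f.tags]
print(experimental)  # ['new_model']
```

Because wrapper is decorated with @wraps(func), each function keeps its own __name__, which is what makes the listing readable.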

Check the return type

def returns(return_type):
  # Complete the returns() decorator
  def decorator(func):
    def wrapper(*args, **kwargs):
      result = func(*args, **kwargs)
      assert type(result) == return_type
      return result
    return wrapper
  return decorator
  
@returns(dict)
def foo(value):
  return value

try:
  print(foo([1,2,3]))
except AssertionError:
  print('foo() did not return a dict!')
foo() did not return a dict!
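One caveat with the assert-based check is that assert statements are skipped when Python runs with the -O flag. A possible variant, sketched here as an assumption rather than as the original design, raises TypeError instead and uses isinstance(), which also accepts subclasses of the expected type.

```python
from functools import wraps


def returns(return_type):
    """Hypothetical variant of returns() that raises TypeError."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if not isinstance(result, return_type):
                raise TypeError('{}() returned {}, expected {}'.format(
                    func.__name__,
                    type(result).__name__,
                    return_type.__name__,
                ))
            return result
        return wrapper
    return decorator


@returns(dict)
def foo(value):
    return value


print(foo({'a': 1}))
try:
    foo([1, 2, 3])
except TypeError as err:
    print(err)  # foo() returned list, expected dict
```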